{
  "cells": [
    {
      "metadata": {
        "id": "-KkvqLgjiIdD"
      },
      "cell_type": "markdown",
      "source": [
        "# Tool Use\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/tool_use.ipynb)\n",
        "\n",
        "Demo to show how to use tool-use with Gemma library.\n",
        "\n",
        "Note: The Gemma 1, 2 and 3 models were not specifically trained for tool use. This is more a proof-of-concept than an officially supported feature."
      ]
    },
    {
      "metadata": {
        "id": "gcNRfVEnj4aq"
      },
      "cell_type": "code",
      "source": [
        "!pip install -q gemma"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "executionInfo": {
          "elapsed": 2221,
          "status": "ok",
          "timestamp": 1749202985345,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "k1ZAgLg1j9NT"
      },
      "cell_type": "code",
      "source": [
        "# Common imports\n",
        "import os\n",
        "import datetime\n",
        "\n",
        "# Gemma imports\n",
        "from gemma import gm"
      ],
      "outputs": [],
      "execution_count": 3
    },
    {
      "metadata": {
        "id": "139lZszJj_CC"
      },
      "cell_type": "markdown",
      "source": [
        "By default, Jax does not utilize the full GPU memory, but this can be overwritten. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
      ]
    },
    {
      "metadata": {
        "executionInfo": {
          "elapsed": 2,
          "status": "ok",
          "timestamp": 1749138071985,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "VtlWWLIYj_LJ"
      },
      "cell_type": "code",
      "source": [
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ],
      "outputs": [],
      "execution_count": 2
    },
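    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "The variable needs to be set before JAX initializes its GPU backend. As a minimal sketch of two options described on the linked page, you could either change the preallocated fraction or disable preallocation entirely. The helper name `configure_gpu_memory` is hypothetical and is not part of Gemma or JAX; only the two environment variables come from the JAX documentation:\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def configure_gpu_memory(fraction=None, preallocate=True):\n",
        "  \"\"\"Hypothetical helper: sets JAX's GPU memory environment variables.\"\"\"\n",
        "  if not preallocate:\n",
        "    # Allocate GPU memory on demand instead of reserving a pool up front.\n",
        "    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'\n",
        "  elif fraction is not None:\n",
        "    # Fraction of the total GPU memory that JAX preallocates.\n",
        "    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = f'{fraction:.2f}'\n",
        "\n",
        "\n",
        "configure_gpu_memory(fraction=1.0)\n",
        "print(os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'])\n",
        "```\n",
        "\n",
        "Disabling preallocation with `configure_gpu_memory(preallocate=False)` can help when several processes share one GPU, at the cost of possible memory fragmentation."
      ]
    },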
    {
      "metadata": {
        "id": "31JPZb5RkD_p"
      },
      "cell_type": "markdown",
      "source": [
        "Load the model and the params."
      ]
    },
    {
      "metadata": {
        "executionInfo": {
          "elapsed": 39057,
          "status": "ok",
          "timestamp": 1749203024713,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "RsAo6k4_kEJS",
        "outputId": "e10afb5c-6c81-42e8-e590-a39ea4ef3bf7"
      },
      "cell_type": "code",
      "source": [
        "model = gm.nn.Gemma3_4B()\n",
        "\n",
        "params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)"
      ],
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO:2025-06-06 02:43:16,896:jax._src.xla_bridge:749: Unable to initialize backend 'pathways': Could not initialize backend 'pathways'\n",
            "INFO:2025-06-06 02:43:16,897:jax._src.xla_bridge:749: Unable to initialize backend 'proxy': INVALID_ARGUMENT: IFRT proxy server address must be '\u003ctransport-type\u003e://\u003cbackend-address\u003e' (e.g., 'grpc://localhost'), but got \n",
            "INFO:2025-06-06 02:43:16,900:jax._src.xla_bridge:749: Unable to initialize backend 'mlcr': Could not initialize backend 'mlcr'\n",
            "INFO:2025-06-06 02:43:16,901:jax._src.xla_bridge:749: Unable to initialize backend 'sliceme': Could not initialize backend 'sliceme'\n"
          ]
        }
      ],
      "execution_count": 4
    },
    {
      "metadata": {
        "id": "p108c5yIlYH7"
      },
      "cell_type": "markdown",
      "source": [
        "## Using existing tools\n",
        "\n",
        "If you're familiar with the [sampling](https://gemma-llm.readthedocs.io/en/latest/sampling.html) tutorial, using tool-use differ in two ways:\n",
        "\n",
        "1. Using the `gm.text.ToolSampler` rather than the `gm.text.ChatSampler`.\n",
        "2. Passing the `tools=` you want to use to the sampler.\n",
        "\n",
        "For example:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "height": 594
        },
        "executionInfo": {
          "elapsed": 50615,
          "status": "ok",
          "timestamp": 1749138791069,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "iRCV5h8BlVX6",
        "outputId": "b3b5d83d-8a8b-4982-fc8f-d409fb8b38a9"
      },
      "cell_type": "code",
      "source": [
        "sampler = gm.text.ToolSampler(\n",
        "    model=model,\n",
        "    params=params,\n",
        "    tools=[\n",
        "        gm.tools.Calculator(),\n",
        "        gm.tools.FileExplorer(),\n",
        "    ],\n",
        "    print_stream=True,\n",
        ")\n",
        "\n",
        "output = sampler.chat('I have a serie `Sn+1 = cos(Sn) * 2`. Using the calculator, compute the steps 0-4 for S0 = 3')"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: I need to compute S0, S1, S2, S3, and S4 using the given recursive formula Sn+1 = cos(Sn) * 2 and S0 = 3.\n",
            "Let's start with S0 = 3.\n",
            "S1 = cos(S0) * 2 = cos(3) * 2\n",
            "S2 = cos(S1) * 2 = cos(cos(3) * 2) * 2\n",
            "S3 = cos(S2) * 2 = cos(cos(cos(3) * 2) * 2) * 2\n",
            "S4 = cos(S3) * 2 = cos(cos(cos(cos(3) * 2) * 2)) * 2\n",
            "\n",
            "I will use the calculator to compute these values.\n",
            "{\"tool_name\": \"calculator\", \"expression\": \"cos(3) * 2\"}\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: -1.9799849932008908]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: Now I need to compute S1 = cos(S0) * 2 = cos(3) * 2 = -1.9799849932008908 * 2\n",
            "{\"tool_name\": \"calculator\", \"expression\": \"-1.9799849932008908 * 2\"}\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: -3.9599699864017817]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: Now I need to compute S2 = cos(S1) * 2 = cos(-3.9599699864017817) * 2\n",
            "{\"tool_name\": \"calculator\", \"expression\": \"cos(-3.9599699864017817) * 2\"}\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: -1.3668134299076982]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: Now I need to compute S3 = cos(S2) * 2 = cos(-1.3668134299076982) * 2\n",
            "{\"tool_name\": \"calculator\", \"expression\": \"cos(-1.3668134299076982) * 2\"}\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: 0.4051424976130353]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: Now I need to compute S4 = cos(S3) * 2 = cos(0.4051424976130353) * 2\n",
            "{\"tool_name\": \"calculator\", \"expression\": \"cos(0.4051424976130353) * 2\"}\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: 1.8380924822033438]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The values are: S0 = 3, S1 = -3.9599699864017817, S2 = -1.3668134299076982, S3 = 0.4051424976130353, S4 = 1.8380924822033438"
          ]
        }
      ],
      "execution_count": 10
    },
    {
      "metadata": {
        "id": "FAI54F-Blkan"
      },
      "cell_type": "markdown",
      "source": [
        "Note: Only the final model answer is returned. You can access the conversation history, including all intermediates tool calls and output through `sampler.turns` property."
      ]
    },
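    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "Because these models were not trained for tool use, it is worth checking the transcript. In the run shown above, the model multiplies the first tool result by 2 a second time, so its reported S1 (about -3.96) is not `cos(3) * 2` (about -1.98), and the later terms inherit that error. A minimal plain-Python sketch of the recurrence from the prompt gives reference values to compare against. The helper name `series` is hypothetical and not part of Gemma:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def series(s0, steps):\n",
        "  \"\"\"Returns [S0, ..., S_steps] for the recurrence S(n+1) = cos(S(n)) * 2.\"\"\"\n",
        "  values = [s0]\n",
        "  for _ in range(steps):\n",
        "    # Apply the recurrence once to the most recent term.\n",
        "    values.append(math.cos(values[-1]) * 2)\n",
        "  return values\n",
        "\n",
        "\n",
        "for n, value in enumerate(series(3, 4)):\n",
        "  print(f'S{n} = {value}')\n",
        "```\n",
        "\n",
        "Comparing these values with the final answer is a quick way to spot steps where the model misused a tool result."
      ]
    },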
    {
      "metadata": {
        "id": "D0_IIS1Nlfuw"
      },
      "cell_type": "markdown",
      "source": [
        "## Creating your own tool\n",
        "\n",
        "To create your own tool, you can inherit from the `gm.tools.Tool` class. You should provide:\n",
        "\n",
        "* A description \u0026 example, so the model knows how to use your tool\n",
        "* Implement the `call` method. The `call` function can take arbitrary `**kwargs`, but the name of the args should match the ones defined in `tool_kwargs` and `tool_kwargs_doc`"
      ]
    },
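    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "To see why the argument names must match, note that the model emits a JSON object such as `{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}`. The following is a simplified illustration and not the library's implementation: `FakeDateTime` and `dispatch` are hypothetical names, and the clock is fixed so that the result is reproducible. It assumes every key other than `tool_name` is forwarded to `call` as a keyword argument:\n",
        "\n",
        "```python\n",
        "import datetime\n",
        "import json\n",
        "\n",
        "\n",
        "class FakeDateTime:\n",
        "  \"\"\"Stand-in for a tool; not a real `gm.tools.Tool` subclass.\"\"\"\n",
        "\n",
        "  def call(self, format: str) -\u003e str:\n",
        "    # Fixed timestamp instead of `datetime.datetime.now()`.\n",
        "    return datetime.datetime(2025, 6, 6, 14, 30).strftime(format)\n",
        "\n",
        "\n",
        "def dispatch(tools, model_output):\n",
        "  \"\"\"Parses a JSON tool call and forwards the remaining keys to `call`.\"\"\"\n",
        "  kwargs = json.loads(model_output)\n",
        "  tool = tools[kwargs.pop('tool_name')]\n",
        "  return tool.call(**kwargs)\n",
        "\n",
        "\n",
        "tools = {'datetime': FakeDateTime()}\n",
        "print(dispatch(tools, '{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}'))\n",
        "```\n",
        "\n",
        "In this sketch, a call that uses a different key, for example `fmt` instead of `format`, raises a `TypeError`, which is why the names in `tool_kwargs` should mirror the `call` signature."
      ]
    },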
    {
      "metadata": {
        "executionInfo": {
          "elapsed": 55,
          "status": "ok",
          "timestamp": 1749203934196,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "XqmQcfdI0oEl"
      },
      "cell_type": "code",
      "source": [
        "class DateTime(gm.tools.Tool):\n",
        "  \"\"\"Tool to access the current date.\"\"\"\n",
        "\n",
        "  DESCRIPTION = 'Access the current date, time,...'\n",
        "  EXAMPLE = gm.tools.Example(\n",
        "      query='Which day of the week are we today ?',\n",
        "      thought='The `datetime.strptime` uses %a for day of the week',\n",
        "      tool_kwargs={'format': '%a'},\n",
        "      tool_kwargs_doc={'format': '\u003cANY datetime.strptime expression\u003e'},\n",
        "      result='Sat',\n",
        "      answer='Today is Saturday.',\n",
        "  )\n",
        "\n",
        "  def call(self, format: str) -\u003e str:\n",
        "    dt = datetime.datetime.now()\n",
        "    return dt.strftime(format)\n"
      ],
      "outputs": [],
      "execution_count": 7
    },
    {
      "metadata": {
        "id": "sSxYhXPuuXYp"
      },
      "cell_type": "markdown",
      "source": [
        "The tool can then be used in the sampler:"
      ]
    },
    {
      "metadata": {
        "colab": {
          "height": 118
        },
        "executionInfo": {
          "elapsed": 2156,
          "status": "ok",
          "timestamp": 1749204833094,
          "user": {
            "displayName": "",
            "userId": ""
          },
          "user_tz": -120
        },
        "id": "9S8xB2B-0cbW",
        "outputId": "fccc0e89-e922-4184-8b77-800041cdd77e"
      },
      "cell_type": "code",
      "source": [
        "sampler = gm.text.ToolSampler(\n",
        "    model=model,\n",
        "    params=params,\n",
        "    tools=[\n",
        "        DateTime(),\n",
        "    ],\n",
        "    print_stream=True,\n",
        ")\n",
        "\n",
        "output = sampler.chat('Which date are we today ?')"
      ],
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Thought: I need to get the current date.\n",
            "{\"tool_name\": \"datetime\", \"format\": \"%Y-%m-%d\"}\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Tool result: 2025-06-06]\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\u003chr\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Today is June 6th, 2025."
          ]
        }
      ],
      "execution_count": 9
    },
    {
      "metadata": {
        "id": "esIpCjhxzHmf"
      },
      "cell_type": "markdown",
      "source": [
        "## Next steps\n",
        "\n",
        "* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.\n",
        "* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {},
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
